"""
Phase 9: Capital Flow Risk — The Risk Manager's Perspective
============================================================
Decompose CA/GDP into predictable (demographics + macro) and idiosyncratic
components, then test whether KAOPEN regime affects the *distribution* of
idiosyncratic shocks — variance, skewness, tails — not just the mean.

Output tables (6):
  1. residual_moments.md     — Distribution moments by KAOPEN tercile
  2. quantile_regression.md  — Z effects across quantiles (0.05–0.95)
  3. opening_risk_premium.md — Does KAOPEN widen the residual distribution?
  4. risk_matrix.md          — Demo stage × KAOPEN regime tail risk matrix
  5. var_vulnerability.md    — Country rankings by forward tail risk
  6. systemic_risk.md        — Portfolio concentration and closure counterfactuals
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')
from scipy import stats as scipy_stats
import statsmodels.api as sm

# Project layout: PROJECT_DIR is the grandparent of this file's directory
# entry (two `.parent` hops from the resolved script path); ROOT_DIR is one
# level above that.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Make the shared multilateral/src package importable so `model.PanelGLS`
# resolves. This must execute before the import on the next line.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and destination for the markdown tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Created at import time so every step can write its table unconditionally.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Number of country-cluster bootstrap draws used for the CIs in Step 2.
N_BOOTSTRAP = 500
# Seed NumPy's global RNG so the bootstrap draws (np.random.choice) are reproducible.
np.random.seed(42)


def stars(p):
    """Map a p-value to conventional significance stars.

    Returns '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1 and ''
    otherwise (including NaN, because every comparison with NaN is False).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    return next((mark for cutoff, mark in thresholds if p < cutoff), '')


def fmt(val, se, p):
    """Format a coefficient and its standard error for a markdown table.

    Returns a pair ``(coef_text, se_text)``: the coefficient to 4 decimals
    with significance stars appended, and the standard error to 4 decimals
    wrapped in parentheses.
    """
    coef_text = "{:.4f}".format(val) + stars(p)
    se_text = "({:.4f})".format(se)
    return coef_text, se_text


# ══════════════════════════════════════════════════════════════════════════
# Step 1: Baseline Residual Extraction
# ══════════════════════════════════════════════════════════════════════════

def extract_residuals(df):
    """
    Estimate the KAOPEN-free baseline equation and attach its residuals.

    Mean equation (PanelGLS): ca_gdp on Z_1, Z_2, Z_3, fiscal_bal_gdp,
    nfa_gdp_lag and rgdp_growth. KAOPEN is deliberately omitted so that later
    steps can study how openness shapes the residual distribution.

    Returns
    -------
    sub : DataFrame
        Rows with complete model variables, plus whichever optional metadata
        columns exist in ``df``, with new ``resid`` and ``fitted`` columns.
    gls : PanelGLS
        The fitted model.
    x_vars : list of str
        Regressor names in the order passed to the model.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("STEP 1: BASELINE RESIDUAL EXTRACTION")
    print(rule)

    x_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    required = ['ca_gdp', *x_vars, 'iso3', 'year']
    optional = ['kaopen', 'old_dep', 'youth_dep', 'demo_tercile',
                'gdp_pc_ppp', 'ca_usd', 'ngdp_usd']
    # Metadata columns are carried along only when the panel has them;
    # rows are dropped only for missing model variables.
    present = [col for col in required + optional if col in df.columns]
    sub = df[present].dropna(subset=required).copy()

    print(f"  Sample: {len(sub)} obs, {sub['iso3'].nunique()} countries")

    gls = PanelGLS()
    y = sub['ca_gdp'].values
    X = sub[x_vars].values
    gls.fit(y, X, sub['iso3'].values, sub['year'].values)

    print(f"  R² = {gls.r_squared:.4f}, ρ = {gls.rho:.4f}")
    for k, name in enumerate(x_vars):
        coef, se, pval = gls.beta[k], gls.se[k], gls.pvalues[k]
        print(f"    {name:25s} {coef:8.4f} ({se:.4f}) {stars(pval)}")

    # NOTE(review): assumes PanelGLS returns resid/fitted aligned with the
    # input row order — confirm in model.PanelGLS.
    sub['resid'] = gls.resid
    sub['fitted'] = gls.fitted

    # Sanity checks
    e = sub['resid']
    print(f"\n  Residual mean: {e.mean():.6f} (should ≈ 0)")
    print(f"  Residual SD:   {e.std():.4f}")
    print(f"  Residual min:  {e.min():.4f}")
    print(f"  Residual max:  {e.max():.4f}")

    return sub, gls, x_vars


# ══════════════════════════════════════════════════════════════════════════
# Step 2: Conditional Moments by KAOPEN Regime
# ══════════════════════════════════════════════════════════════════════════

def _left_tail(r, pct):
    """
    Empirical left-tail risk of a 1-D array of residuals.

    Returns ``(var, cvar)``: ``var`` is the ``pct``-th percentile of ``r`` and
    ``cvar`` is the mean of all observations at or below it (``var`` itself if
    that set is empty). Lower values = worse CA outcomes.
    """
    var = np.percentile(r, pct)
    tail = r[r <= var]
    return var, (tail.mean() if tail.size > 0 else var)


def compute_moments(df):
    """
    Compute distribution moments of residuals by KAOPEN tercile.

    Adds a ``kaopen_tercile`` column (Closed / Mid / Open, cut at the pooled
    1/3 and 2/3 quantiles of ``kaopen``) to ``df`` in place. For each tercile
    with at least 20 residuals it reports N, mean, SD, skewness, excess
    kurtosis and left-tail VaR/CVaR at 5% and 10%, plus country-cluster
    bootstrap 95% CIs for SD, VaR5 and CVaR5. Levene (mean-centred) and
    Brown-Forsythe (median-centred Levene) tests compare variances across the
    qualifying terciles. Writes ``residual_moments.md``.

    Parameters
    ----------
    df : DataFrame
        Needs ``kaopen``, ``resid`` and ``iso3`` columns.

    Returns
    -------
    (df, rows, (levene_stat, levene_p, bf_stat, bf_p))
        ``rows`` holds one dict per tercile with >= 20 observations. The test
        statistics are NaN when fewer than two terciles qualify.
    """
    print("\n" + "=" * 70)
    print("STEP 2: CONDITIONAL MOMENTS BY KAOPEN REGIME")
    print("=" * 70)

    # KAOPEN terciles (same construction as phase 7)
    kaopen_valid = df['kaopen'].dropna()
    t1 = kaopen_valid.quantile(1/3)
    t2 = kaopen_valid.quantile(2/3)
    # pd.cut bins are right-closed: Closed <= t1 < Mid <= t2 < Open.
    df['kaopen_tercile'] = pd.cut(df['kaopen'], bins=[-np.inf, t1, t2, np.inf],
                                   labels=['Closed', 'Mid', 'Open'])
    print(f"  KAOPEN cutoffs: ≤{t1:.2f} (Closed), {t1:.2f}–{t2:.2f} (Mid), >{t2:.2f} (Open)")

    rows = []
    tercile_resids = {}

    for tercile in ['Closed', 'Mid', 'Open']:
        sub = df[df['kaopen_tercile'] == tercile]
        r = sub['resid'].dropna().values
        tercile_resids[tercile] = r

        if len(r) < 20:
            print(f"  {tercile}: insufficient obs ({len(r)}), skipping")
            continue

        var5, cvar5 = _left_tail(r, 5)
        var10, cvar10 = _left_tail(r, 10)

        row = {
            'Tercile': tercile,
            'N': len(r),
            'Mean': np.mean(r),
            'SD': np.std(r, ddof=1),
            'Skewness': scipy_stats.skew(r),
            'Excess Kurtosis': scipy_stats.kurtosis(r),
            'VaR5': var5,
            'VaR10': var10,
            'CVaR5': cvar5,
            'CVaR10': cvar10,
        }

        # ── Bootstrap CIs (country-level cluster) ──
        # Whole countries are resampled with replacement, keeping each
        # country's residuals together. Per-country arrays are extracted once
        # (NaNs dropped, matching `r` above) instead of re-filtering `sub` for
        # every sampled country in every draw.
        countries = sub['iso3'].unique()
        resid_by_country = {
            c: sub.loc[sub['iso3'] == c, 'resid'].dropna().values for c in countries
        }
        boot_sd, boot_var5, boot_cvar5 = [], [], []

        for _ in range(N_BOOTSTRAP):
            boot_c = np.random.choice(countries, size=len(countries), replace=True)
            boot_r = np.concatenate([resid_by_country[c] for c in boot_c])

            if len(boot_r) < 10:
                continue
            boot_sd.append(np.std(boot_r, ddof=1))
            bv5, bcv5 = _left_tail(boot_r, 5)
            boot_var5.append(bv5)
            boot_cvar5.append(bcv5)

        # Report CIs only if enough draws survived the size filter.
        if len(boot_sd) > 50:
            row['SD_CI_lo'] = np.percentile(boot_sd, 2.5)
            row['SD_CI_hi'] = np.percentile(boot_sd, 97.5)
            row['VaR5_CI_lo'] = np.percentile(boot_var5, 2.5)
            row['VaR5_CI_hi'] = np.percentile(boot_var5, 97.5)
            row['CVaR5_CI_lo'] = np.percentile(boot_cvar5, 2.5)
            row['CVaR5_CI_hi'] = np.percentile(boot_cvar5, 97.5)

        rows.append(row)
        print(f"  {tercile}: N={len(r)}, SD={row['SD']:.3f}, VaR5={var5:.3f}, CVaR5={cvar5:.3f}")

    # ── Variance equality tests ──
    groups = [tercile_resids[t] for t in ['Closed', 'Mid', 'Open']
              if t in tercile_resids and len(tercile_resids[t]) >= 20]

    levene_stat, levene_p = scipy_stats.levene(*groups) if len(groups) >= 2 else (np.nan, np.nan)
    # Brown-Forsythe = Levene with median centring (robust to non-normal tails).
    bf_stat, bf_p = scipy_stats.levene(*groups, center='median') if len(groups) >= 2 else (np.nan, np.nan)

    print(f"\n  Levene test:         F={levene_stat:.3f}, p={levene_p:.4f}")
    print(f"  Brown-Forsythe test: F={bf_stat:.3f}, p={bf_p:.4f}")

    # ── Write table ──
    lines = ["# Residual Distribution Moments by KAOPEN Tercile\n"]
    lines.append("| Statistic | Closed | Mid | Open |")
    lines.append("|:---|---:|---:|---:|")

    # (row label, key in the per-tercile dict or None for a CI row, format)
    stats_to_show = [
        ('N', 'N', '{:.0f}'),
        ('Mean', 'Mean', '{:.4f}'),
        ('SD', 'SD', '{:.4f}'),
        ('SD 95% CI', None, None),
        ('Skewness', 'Skewness', '{:.4f}'),
        ('Excess Kurtosis', 'Excess Kurtosis', '{:.4f}'),
        ('VaR5 (pp GDP)', 'VaR5', '{:.3f}'),
        ('VaR5 95% CI', None, None),
        ('VaR10 (pp GDP)', 'VaR10', '{:.3f}'),
        ('CVaR5 (pp GDP)', 'CVaR5', '{:.3f}'),
        ('CVaR5 95% CI', None, None),
        ('CVaR10 (pp GDP)', 'CVaR10', '{:.3f}'),
    ]

    row_map = {r['Tercile']: r for r in rows}

    ci_keys = {
        'SD 95% CI': ('SD_CI_lo', 'SD_CI_hi'),
        'VaR5 95% CI': ('VaR5_CI_lo', 'VaR5_CI_hi'),
        'CVaR5 95% CI': ('CVaR5_CI_lo', 'CVaR5_CI_hi'),
    }

    for label, key, fmt_str in stats_to_show:
        row_line = f"| {label} |"
        for t in ['Closed', 'Mid', 'Open']:
            if t not in row_map:
                row_line += " — |"
                continue
            r = row_map[t]
            if key is not None:
                row_line += f" {fmt_str.format(r[key])} |"
            elif label in ci_keys:
                lo_key, hi_key = ci_keys[label]
                if lo_key in r:
                    row_line += f" [{r[lo_key]:.3f}, {r[hi_key]:.3f}] |"
                else:
                    row_line += " — |"
        lines.append(row_line)

    lines.append("|:---|---:|---:|---:|")
    lines.append(f"| Levene F | | {levene_stat:.3f} | p={levene_p:.4f} |")
    lines.append(f"| Brown-Forsythe F | | {bf_stat:.3f} | p={bf_p:.4f} |")

    lines.append("\n*Residuals from PanelGLS: ca_gdp ~ Z₁ + Z₂ + Z₃ + fiscal_bal + NFA_lag + growth. "
                 "KAOPEN deliberately excluded from mean equation.*")
    lines.append("*VaR/CVaR are left-tail (worst CA outcomes). "
                 f"Bootstrap: {N_BOOTSTRAP} country-cluster iterations. 95% CIs in brackets.*")

    path = OUT_TABLES / "residual_moments.md"
    path.write_text('\n'.join(lines))
    print(f"\n  Saved: {path}")

    return df, rows, (levene_stat, levene_p, bf_stat, bf_p)


# ══════════════════════════════════════════════════════════════════════════
# Step 3: Quantile Regression
# ══════════════════════════════════════════════════════════════════════════

def quantile_regression(df):
    """
    Quantile regression at τ = {0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95}.

    Panel A: ca_gdp on Z_1..Z_3 and macro controls.
    Panel B: adds kaopen and Z_k × kaopen interactions. Interaction columns
    missing from ``df`` are created in place as ``Z_k * kaopen``.

    Each panel is a pooled statsmodels QuantReg on complete cases with an
    explicitly prepended constant, which is dropped from the stored results.
    Standard errors are QuantReg's default (not clustered by country).
    Writes ``quantile_regression.md``.

    Returns
    -------
    (results_a, results_b) : tuple of dict
        Keyed by τ; each value holds ``params``, ``pvalues`` and ``se`` arrays
        ordered like ``base_vars`` (Panel A) or ``interact_vars`` (Panel B).
    """
    print("\n" + "=" * 70)
    print("STEP 3: QUANTILE REGRESSION")
    print("=" * 70)

    base_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    interact_vars = base_vars + ['kaopen', 'Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']

    # Ensure interaction terms exist
    for zv in ['Z_1', 'Z_2', 'Z_3']:
        iv = f'{zv}_x_kaopen'
        if iv not in df.columns:
            df[iv] = df[zv] * df['kaopen']

    quantiles = [0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95]

    # Column positions in the design matrices (+1 for the leading constant);
    # used only for the console progress lines.
    z1_col = 1 + base_vars.index('Z_1')
    z1k_col = 1 + interact_vars.index('Z_1_x_kaopen')

    results_a = {}  # Panel A
    results_b = {}  # Panel B

    # Panel A
    sub_a = df[['ca_gdp'] + base_vars].dropna()
    # has_constant='add' guarantees column 0 is the intercept, so the [1:]
    # slices below always line up with base_vars.
    X_a = sm.add_constant(sub_a[base_vars].values, has_constant='add')
    y_a = sub_a['ca_gdp'].values
    print(f"\n  Panel A: N={len(sub_a)}")

    for tau in quantiles:
        qr = sm.QuantReg(y_a, X_a).fit(q=tau, max_iter=1000)
        results_a[tau] = {
            'params': qr.params[1:],  # skip constant
            'pvalues': qr.pvalues[1:],
            'se': qr.bse[1:],
        }
        print(f"    τ={tau:.2f}: Z₁={qr.params[z1_col]:.4f} (p={qr.pvalues[z1_col]:.4f})")

    # Panel B
    sub_b = df[['ca_gdp'] + interact_vars].dropna()
    X_b = sm.add_constant(sub_b[interact_vars].values, has_constant='add')
    y_b = sub_b['ca_gdp'].values
    print(f"\n  Panel B (interactions): N={len(sub_b)}")

    for tau in quantiles:
        qr = sm.QuantReg(y_b, X_b).fit(q=tau, max_iter=1000)
        results_b[tau] = {
            'params': qr.params[1:],  # skip constant
            'pvalues': qr.pvalues[1:],
            'se': qr.bse[1:],
        }
        print(f"    τ={tau:.2f}: Z₁×KAOPEN={qr.params[z1k_col]:.4f} "
              f"(p={qr.pvalues[z1k_col]:.4f})")

    def _coef_rows(var_names, results):
        """Return alternating coefficient / SE table rows, one pair per variable."""
        out = []
        for i, v in enumerate(var_names):
            coef_row = f"| {v} |"
            se_row = "| |"
            for tau in quantiles:
                cv, sv = fmt(results[tau]['params'][i], results[tau]['se'][i],
                             results[tau]['pvalues'][i])
                coef_row += f" {cv} |"
                se_row += f" {sv} |"
            out.extend([coef_row, se_row])
        return out

    # ── Write table ──
    lines = ["# Quantile Regression: Demographics Across the CA/GDP Distribution\n"]
    header = "| Variable | " + " | ".join([f"τ={t:.2f}" for t in quantiles]) + " |"
    sep = "|:---|" + "|".join(["---:" for _ in quantiles]) + "|"

    # Panel A
    lines.append("## Panel A: Base Model\n")
    lines.append(header)
    lines.append(sep)
    lines.extend(_coef_rows(base_vars, results_a))
    lines.append(f"\n*N = {len(sub_a)}. Pooled quantile regression (Koenker & Bassett).*\n")

    # Panel B
    lines.append("## Panel B: With KAOPEN Interactions\n")
    lines.append(header)
    lines.append(sep)
    lines.extend(_coef_rows(interact_vars, results_b))

    lines.append(f"\n*N = {len(sub_b)}. Pooled quantile regression with Z×KAOPEN interactions.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / "quantile_regression.md"
    path.write_text('\n'.join(lines))
    print(f"\n  Saved: {path}")

    return results_a, results_b


# ══════════════════════════════════════════════════════════════════════════
# Step 4: Risk Premium of Opening
# ══════════════════════════════════════════════════════════════════════════

def risk_premium(df, moments_rows, var_tests):
    """
    Test whether KAOPEN widens the residual distribution.

    Adds ``resid_sq`` and ``resid_abs`` columns to ``df`` (in place) and
    regresses each on kaopen + controls with PanelGLS. A model is skipped
    when it has fewer than 50 complete observations. These volatility
    regressions are combined with the Open/Closed SD and VaR5 ratios from
    ``moments_rows`` and the variance-equality tests in ``var_tests`` (both
    as returned by ``compute_moments``). Writes ``opening_risk_premium.md``.

    Parameters
    ----------
    df : DataFrame
        Needs ``resid``, ``kaopen``, the controls, ``iso3`` and ``year``.
    moments_rows : list of dict
        Per-tercile moment rows keyed by ``'Tercile'``.
    var_tests : tuple
        ``(levene_stat, levene_p, bf_stat, bf_p)``.
    """
    print("\n" + "=" * 70)
    print("STEP 4: RISK PREMIUM OF FINANCIAL OPENING")
    print("=" * 70)

    levene_stat, levene_p, bf_stat, bf_p = var_tests

    df['resid_sq'] = df['resid'] ** 2
    df['resid_abs'] = df['resid'].abs()

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    x_vars = ['kaopen'] + controls
    # (dependent column, label used both in the table and on the console)
    volatility_models = [('resid_sq', 'resid²'), ('resid_abs', '|resid|')]
    results = []

    for dep_var, label in volatility_models:
        sub = df[[dep_var] + x_vars + ['iso3', 'year']].dropna()
        if len(sub) < 50:
            continue  # too few complete observations for a panel fit
        gls = PanelGLS()
        gls.fit(sub[dep_var].values, sub[x_vars].values,
                sub['iso3'].values, sub['year'].values)
        r = {'model': label, 'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
             'r_squared': gls.r_squared}
        for i, v in enumerate(x_vars):
            r[f'{v}_coef'] = gls.beta[i]
            r[f'{v}_se'] = gls.se[i]
            r[f'{v}_p'] = gls.pvalues[i]
        results.append(r)
        # kaopen is the first regressor, hence index 0.
        print(f"  {label} ~ kaopen: coef={gls.beta[0]:.4f} ({gls.se[0]:.4f}) "
              f"p={gls.pvalues[0]:.4f} {stars(gls.pvalues[0])}")

    # ── Summary ratios ──
    # NOTE: the VaR5 ratio reads as "Open tail is deeper" (> 1) only when both
    # VaR5 values are negative, as expected for near-zero-mean residuals.
    row_map = {r['Tercile']: r for r in moments_rows}
    sd_ratio = var5_ratio = np.nan
    if 'Open' in row_map and 'Closed' in row_map:
        sd_ratio = row_map['Open']['SD'] / row_map['Closed']['SD']
        var5_ratio = row_map['Open']['VaR5'] / row_map['Closed']['VaR5']
        print(f"\n  SD ratio (Open/Closed):   {sd_ratio:.3f}")
        print(f"  VaR5 ratio (Open/Closed): {var5_ratio:.3f}")

    # ── Write table ──
    lines = ["# Risk Premium of Financial Opening\n"]

    # Panel A: Regressions
    lines.append("## Panel A: Volatility Regressions\n")
    if results:
        model_labels = [r['model'] for r in results]
        header = "| Variable | " + " | ".join(model_labels) + " |"
        sep = "|:---|" + "|".join(["---:" for _ in results]) + "|"
        lines.append(header)
        lines.append(sep)

        for v in x_vars:
            coef_row = f"| {v} |"
            se_row = "| |"
            for r in results:
                if f'{v}_coef' in r:
                    cv, sv = fmt(r[f'{v}_coef'], r[f'{v}_se'], r[f'{v}_p'])
                    coef_row += f" {cv} |"
                    se_row += f" {sv} |"
                else:
                    coef_row += " |"
                    se_row += " |"
            lines.append(coef_row)
            lines.append(se_row)

        lines.append("|:---|" + "|".join(["---:" for _ in results]) + "|")
        n_row = "| N |"
        r2_row = "| R² |"
        for r in results:
            n_row += f" {r['n_obs']} |"
            r2_row += f" {r['r_squared']:.4f} |"
        lines.append(n_row)
        lines.append(r2_row)

    # Panel B: Summary
    lines.append("\n## Panel B: Summary Statistics\n")
    lines.append("| Measure | Value |")
    lines.append("|:---|---:|")
    lines.append(f"| SD ratio (Open / Closed) | {sd_ratio:.3f} |")
    lines.append(f"| VaR5 ratio (Open / Closed) | {var5_ratio:.3f} |")
    lines.append(f"| Levene test (F) | {levene_stat:.3f} |")
    lines.append(f"| Levene p-value | {levene_p:.4f} |")
    lines.append(f"| Brown-Forsythe test (F) | {bf_stat:.3f} |")
    lines.append(f"| Brown-Forsythe p-value | {bf_p:.4f} |")

    lines.append("\n*Panel A: PanelGLS of squared/absolute residuals on KAOPEN + controls. "
                 "Positive KAOPEN coefficient = opening increases CA volatility.*")
    lines.append("*Panel B: Open/Closed ratios >1 indicate wider distributions under openness. "
                 "Levene/Brown-Forsythe test H₀: equal variances across KAOPEN terciles.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / "opening_risk_premium.md"
    path.write_text('\n'.join(lines))
    print(f"\n  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════════
# Step 5: Risk Matrix (Demo Stage × KAOPEN Regime)
# ══════════════════════════════════════════════════════════════════════════

def risk_matrix(df):
    """
    3×3 cross-tabulation: demographic stage × KAOPEN tercile.

    Per cell: N, residual SD, VaR5, CVaR5 and skewness of ``resid``. Cells
    with fewer than 20 residuals keep their N but report NaN (rendered "—").
    Requires ``demo_tercile`` (cells are matched against 'early' / 'mid' /
    'late') and ``kaopen_tercile`` (matched against 'Closed' / 'Mid' /
    'Open'). If either column is absent, a message is printed and the step
    returns without writing a table. Otherwise writes ``risk_matrix.md``.
    """
    print("\n" + "=" * 70)
    print("STEP 5: RISK MATRIX (DEMO STAGE × KAOPEN REGIME)")
    print("=" * 70)

    # demo_tercile is optional in the Step 1 sample, so guard before indexing
    # rather than failing with a KeyError.
    missing = [c for c in ('demo_tercile', 'kaopen_tercile') if c not in df.columns]
    if missing:
        print(f"  Missing {', '.join(missing)} — skipping risk matrix")
        return

    demo_stages = ['early', 'mid', 'late']
    kaopen_terciles = ['Closed', 'Mid', 'Open']

    # Build matrix
    cells = {}
    for ds in demo_stages:
        for kt in kaopen_terciles:
            sub = df[(df['demo_tercile'] == ds) & (df['kaopen_tercile'] == kt)]
            r = sub['resid'].dropna().values
            n = len(r)

            if n < 20:
                cells[(ds, kt)] = {'N': n, 'SD': np.nan, 'VaR5': np.nan,
                                   'CVaR5': np.nan, 'Skewness': np.nan}
                print(f"  {ds} × {kt}: N={n} (< 20, skipped)")
                continue

            sd = np.std(r, ddof=1)
            var5 = np.percentile(r, 5)
            tail = r[r <= var5]
            cvar5 = tail.mean() if tail.size > 0 else var5

            cells[(ds, kt)] = {
                'N': n,
                'SD': sd,
                'VaR5': var5,
                'CVaR5': cvar5,
                'Skewness': scipy_stats.skew(r),
            }
            print(f"  {ds:6s} × {kt:6s}: N={n:5d}, SD={sd:.3f}, "
                  f"VaR5={var5:.3f}, CVaR5={cvar5:.3f}")

    # ── Write table ──
    lines = ["# Tail Risk Matrix: Demographic Stage × Capital Account Openness\n"]

    for stat, label, fmt_str in [
        ('N', 'N (observations)', '{:.0f}'),
        ('SD', 'Residual SD (pp GDP)', '{:.3f}'),
        ('VaR5', 'VaR5 (pp GDP)', '{:.3f}'),
        ('CVaR5', 'CVaR5 (pp GDP)', '{:.3f}'),
        ('Skewness', 'Skewness', '{:.3f}'),
    ]:
        lines.append(f"\n## {label}\n")
        lines.append("| Demo Stage | Closed | Mid | Open |")
        lines.append("|:---|---:|---:|---:|")
        for ds in demo_stages:
            row = f"| {ds.capitalize()} |"
            for kt in kaopen_terciles:
                val = cells[(ds, kt)][stat]
                if pd.isna(val):
                    row += " — |"
                else:
                    row += f" {fmt_str.format(val)} |"
            lines.append(row)

    lines.append("\n*Cells with N < 20 suppressed. VaR5/CVaR5 are left-tail (worst CA outcomes). "
                 "Demo stages from pooled demo_tercile (early/mid/late transition). "
                 "KAOPEN terciles from pooled distribution.*")

    path = OUT_TABLES / "risk_matrix.md"
    path.write_text('\n'.join(lines))
    print(f"\n  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════════
# Step 6: Forward Vulnerability Rankings
# ══════════════════════════════════════════════════════════════════════════

def _dash_fmt(value, spec):
    """Format ``value`` with format ``spec``, or return '—' if it is missing."""
    return format(value, spec) if pd.notna(value) else "—"


def vulnerability_rankings(df):
    """
    Rank countries by implied forward left-tail CA/GDP risk.

    For each country whose latest row is from 2020 or later and that has at
    least 5 residuals in total: implied VaR5 = that row's fitted CA/GDP + the
    country's empirical 5th-percentile residual (Gaussian fallback
    fitted − 1.645 × residual SD where the percentile is unavailable). Writes
    the 25 most and 10 least vulnerable countries to ``var_vulnerability.md``.
    Returns without writing a table when no country qualifies.
    """
    print("\n" + "=" * 70)
    print("STEP 6: FORWARD VULNERABILITY RANKINGS")
    print("=" * 70)

    # Country-level residual stats (full history)
    country_stats = df.groupby('iso3').agg(
        resid_mean=('resid', 'mean'),
        resid_sd=('resid', 'std'),
        resid_skew=('resid', lambda x: scipy_stats.skew(x.dropna()) if len(x.dropna()) >= 5 else np.nan),
        resid_p5=('resid', lambda x: np.percentile(x.dropna(), 5) if len(x.dropna()) >= 5 else np.nan),
        n_obs=('resid', 'count'),
    ).reset_index()

    # Most recent row (year >= 2020) per country. tail(1) keeps that whole
    # row; GroupBy.last() would take the last non-null value column by column
    # and could mix years (e.g. kaopen from an earlier year than ca_gdp).
    recent = (df[df['year'] >= 2020]
              .sort_values('year')
              .groupby('iso3')
              .tail(1))

    merged = recent.merge(country_stats, on='iso3', how='inner')
    merged = merged[merged['n_obs'] >= 5].copy()  # need enough history

    # Empirical VaR5
    merged['implied_var5_empirical'] = merged['fitted'] + merged['resid_p5']
    # Gaussian fallback
    merged['implied_var5_gaussian'] = merged['fitted'] - 1.645 * merged['resid_sd']

    # Use empirical where available, gaussian as fallback
    merged['implied_var5'] = np.where(
        merged['resid_p5'].notna(),
        merged['implied_var5_empirical'],
        merged['implied_var5_gaussian']
    )

    # Sort by vulnerability (most negative = most vulnerable)
    merged = merged.sort_values('implied_var5')

    print(f"  Countries with recent data: {len(merged)}")
    if merged.empty:
        print("  No countries with data ≥ 2020 and ≥5 residuals — skipping rankings")
        return

    header = ("| Rank | Country | Actual CA/GDP | Fitted CA | Resid SD | Skewness | "
              "VaR5 | KAOPEN | Old Dep | Youth Dep |")
    align = "|---:|:---|---:|---:|---:|---:|---:|---:|---:|---:|"

    def _table_row(rank, r):
        """Render one ranking row; optional metadata columns may be absent."""
        return (f"| {rank} | {r['iso3']} | {_dash_fmt(r['ca_gdp'], '.2f')} | "
                f"{_dash_fmt(r['fitted'], '.2f')} | {_dash_fmt(r['resid_sd'], '.2f')} | "
                f"{_dash_fmt(r['resid_skew'], '.2f')} | {_dash_fmt(r['implied_var5'], '.2f')} | "
                f"{_dash_fmt(r.get('kaopen', np.nan), '.2f')} | "
                f"{_dash_fmt(r.get('old_dep', np.nan), '.1f')} | "
                f"{_dash_fmt(r.get('youth_dep', np.nan), '.1f')} |")

    # ── Write table ──
    lines = ["# Forward CA/GDP Vulnerability Rankings\n"]

    # Top 25 most vulnerable
    lines.append("## Most Vulnerable (Top 25)\n")
    lines.append(header)
    lines.append(align)
    top25 = merged.head(25)
    for rank, (_, r) in enumerate(top25.iterrows(), 1):
        lines.append(_table_row(rank, r))

    # Bottom 10 safest (safest first)
    lines.append("\n## Safest (Bottom 10)\n")
    lines.append(header)
    lines.append(align)
    bottom10 = merged.tail(10).iloc[::-1]
    for rank, (_, r) in enumerate(bottom10.iterrows(), 1):
        lines.append(_table_row(rank, r))

    lines.append("\n*VaR5 = fitted CA + country's empirical 5th percentile residual. "
                 "Gaussian fallback: fitted − 1.645 × SD. "
                 "Requires ≥5 historical observations.*")
    lines.append("*Rankings based on most recent year ≥ 2020. Lower VaR5 = more vulnerable.*")

    path = OUT_TABLES / "var_vulnerability.md"
    path.write_text('\n'.join(lines))
    print(f"\n  Saved: {path}")

    # Print highlights
    print(f"\n  Most vulnerable: {top25.iloc[0]['iso3']} (VaR5={top25.iloc[0]['implied_var5']:.2f})")
    safe = merged.tail(1).iloc[0]
    print(f"  Safest:          {safe['iso3']} (VaR5={safe['implied_var5']:.2f})")


# ══════════════════════════════════════════════════════════════════════════
# Step 7: Systemic / Portfolio Risk — Node Closure Counterfactuals
# ══════════════════════════════════════════════════════════════════════════

def systemic_risk(df):
    """
    Portfolio view of global current-account risk.

    All parts use rows with year >= 2015:
      A. USD-weighted residual variance by country (resid / 100 × ngdp_usd),
         variance shares, HHI and effective number of countries.
      B. Flow concentration in the latest year: surplus / deficit totals, US
         share of deficits, CHN+DEU+JPN share of surpluses, top-5 shares.
      C. Closure counterfactuals (US, CHN+DEU+JPN, DEU) under pro-rata
         reallocation.
      D. Yearly sum of USD-weighted residuals, with its SD and 5% quantile.

    Adds a ``resid_usd`` column to ``df`` in place and writes
    ``systemic_risk.md``. Returns early (no table) if ``ca_usd`` /
    ``ngdp_usd`` are missing, or if no country has non-zero ca_usd in the
    latest year.

    NOTE(review): the "$bn" labels assume ngdp_usd and ca_usd are in billions
    of USD and resid is in percent of GDP — confirm against the panel build.
    """
    print("\n" + "=" * 70)
    print("STEP 7: SYSTEMIC / PORTFOLIO RISK")
    print("=" * 70)

    # Need ca_usd and ngdp_usd
    if 'ca_usd' not in df.columns or 'ngdp_usd' not in df.columns:
        print("  Missing ca_usd or ngdp_usd — skipping systemic analysis")
        return

    # USD-weighted residuals: resid (in pp) × GDP = residual in $bn
    df['resid_usd'] = df['resid'] / 100 * df['ngdp_usd']

    # ── A. Variance concentration ──
    # Use recent period for cross-section relevance
    recent = df[df['year'] >= 2015].copy()

    country_var = recent.groupby('iso3').agg(
        resid_usd_var=('resid_usd', 'var'),
        resid_sd_pct=('resid', 'std'),
        mean_gdp=('ngdp_usd', 'mean'),
        mean_ca_usd=('ca_usd', 'mean'),
        n=('resid_usd', 'count'),
    ).dropna()
    country_var = country_var[country_var['n'] >= 3].copy()

    total_var = country_var['resid_usd_var'].sum()
    country_var['var_share'] = country_var['resid_usd_var'] / total_var * 100
    country_var = country_var.sort_values('var_share', ascending=False)

    # HHI on percentage shares, rescaled to the 0–1 range.
    hhi = (country_var['var_share'] ** 2).sum() / 10000
    eff_n = 1 / hhi if hhi > 0 else np.nan

    print(f"  Variance HHI: {hhi*10000:.0f}")
    print(f"  Effective number of countries: {eff_n:.1f}")
    top15 = country_var.head(15)
    # Label with the number actually listed (head(15) may return fewer rows).
    print(f"\n  Top {len(top15)} contributors to USD-weighted residual variance:")
    cum = 0
    for iso3, r in top15.iterrows():
        cum += r['var_share']
        print(f"    {iso3:4s}: SD={r['resid_sd_pct']:.2f}pp, "
              f"GDP=${r['mean_gdp']:.0f}bn, CA=${r['mean_ca_usd']:+.0f}bn, "
              f"VarShare={r['var_share']:.1f}%, Cum={cum:.1f}%")

    # ── B. Flow concentration (cross-section, most recent year) ──
    latest_year = recent['year'].max()
    cross = recent[recent['year'] == latest_year].copy()
    cross = cross[cross['ca_usd'].notna()].copy()

    total_abs_ca = cross['ca_usd'].abs().sum()
    # Every share below divides by these totals; bail out on an empty cross-section.
    if cross.empty or total_abs_ca <= 0:
        print(f"  No ca_usd data for latest year ({latest_year}) — skipping systemic analysis")
        return

    surplus = cross[cross['ca_usd'] > 0].sort_values('ca_usd', ascending=False)
    deficit = cross[cross['ca_usd'] < 0].sort_values('ca_usd')
    total_surplus = surplus['ca_usd'].sum()   # >= 0
    total_deficit = deficit['ca_usd'].sum()   # <= 0

    # US share of deficits
    us_row = cross[cross['iso3'] == 'USA']
    us_deficit = us_row['ca_usd'].values[0] if len(us_row) > 0 else 0
    us_deficit_share = abs(us_deficit) / abs(total_deficit) * 100 if total_deficit != 0 else 0

    # Big 3 surplus share
    big3 = ['CHN', 'DEU', 'JPN']
    big3_surplus = surplus[surplus['iso3'].isin(big3)]['ca_usd'].sum()
    big3_share = big3_surplus / total_surplus * 100 if total_surplus > 0 else 0

    print(f"\n  Flow concentration ({latest_year}):")
    print(f"    Total absolute CA: ${total_abs_ca:.0f}bn")
    print(f"    US deficit: ${us_deficit:.0f}bn ({us_deficit_share:.1f}% of global deficits)")
    print(f"    CHN+DEU+JPN surplus: ${big3_surplus:.0f}bn ({big3_share:.1f}% of global surpluses)")

    # Top 5 surplus / deficit concentration (NaN if that side is empty)
    top5_surplus = surplus.head(5)['ca_usd'].sum()
    top5_deficit = deficit.head(5)['ca_usd'].sum()
    top5_surplus_share = top5_surplus / total_surplus * 100 if total_surplus > 0 else np.nan
    top5_deficit_share = (abs(top5_deficit) / abs(total_deficit) * 100
                          if total_deficit < 0 else np.nan)
    print(f"    Top 5 surplus share: {top5_surplus_share:.1f}%")
    print(f"    Top 5 deficit share: {top5_deficit_share:.1f}%")

    # ── C. Counterfactual closure scenarios ──
    # For each scenario: if a node closes, compute stranded flows
    scenarios = []

    # Scenario 1: US closes (largest importer)
    if total_deficit != 0:
        us_share_frac = abs(us_deficit) / abs(total_deficit)
        stranded_surplus = []
        for _, r in surplus.head(10).iterrows():
            stranded_surplus.append({
                'iso3': r['iso3'],
                'surplus': r['ca_usd'],
                'stranded': r['ca_usd'] * us_share_frac,
            })
        scenarios.append({
            'name': 'US closure',
            'flow_removed': abs(us_deficit),
            'share_of_system': abs(us_deficit) / total_abs_ca * 100,
            'stranded': pd.DataFrame(stranded_surplus),
        })

    # Scenario 2: CHN+DEU+JPN close (largest exporters)
    if total_surplus > 0:
        big3_share_frac = big3_surplus / total_surplus
        stranded_deficit = []
        for _, r in deficit.head(10).iterrows():
            stranded_deficit.append({
                'iso3': r['iso3'],
                'deficit': r['ca_usd'],
                'lost_financing': abs(r['ca_usd']) * big3_share_frac,
            })
        scenarios.append({
            'name': 'CHN+DEU+JPN closure',
            'flow_removed': big3_surplus,
            'share_of_system': big3_surplus / total_abs_ca * 100,
            'stranded': pd.DataFrame(stranded_deficit),
        })

    # Scenario 3: DEU closure alone (largest single surplus). Only meaningful
    # when DEU actually runs a surplus, which also keeps total_surplus > 0.
    deu_row = cross[cross['iso3'] == 'DEU']
    deu_surplus = deu_row['ca_usd'].values[0] if len(deu_row) > 0 else 0
    if deu_surplus > 0:
        deu_share_frac = deu_surplus / total_surplus
        stranded_deficit_deu = []
        for _, r in deficit.head(10).iterrows():
            stranded_deficit_deu.append({
                'iso3': r['iso3'],
                'deficit': r['ca_usd'],
                'lost_financing': abs(r['ca_usd']) * deu_share_frac,
            })
        scenarios.append({
            'name': 'DEU closure',
            'flow_removed': deu_surplus,
            'share_of_system': deu_surplus / total_abs_ca * 100,
            'stranded': pd.DataFrame(stranded_deficit_deu),
        })

    # ── D. System-level aggregate residual ──
    # Count only countries whose USD residual is available, i.e. those that
    # actually enter the yearly sum.
    contributing = recent[recent['resid_usd'].notna()]
    yearly_agg = contributing.groupby('year').agg(
        total_resid_usd=('resid_usd', 'sum'),
        n_countries=('iso3', 'nunique'),
    ).reset_index()
    sys_sd = yearly_agg['total_resid_usd'].std()
    sys_var5 = yearly_agg['total_resid_usd'].quantile(0.05) if len(yearly_agg) >= 5 else np.nan

    print(f"\n  System-level aggregate residual (2015–{latest_year}):")
    print(f"    SD: ${sys_sd:.0f}bn")
    if not np.isnan(sys_var5):
        print(f"    VaR5: ${sys_var5:.0f}bn")

    # ── Write table ──
    lines = ["# Systemic Capital Flow Risk: Concentration and Closure Counterfactuals\n"]

    # Panel A: Variance concentration
    lines.append("## Panel A: USD-Weighted Residual Variance Concentration\n")
    lines.append("| Rank | Country | Resid SD (pp) | GDP ($bn) | CA ($bn) | Var Share (%) | Cum (%) |")
    lines.append("|---:|:---|---:|---:|---:|---:|---:|")
    cum = 0
    for rank, (iso3, r) in enumerate(top15.iterrows(), 1):
        cum += r['var_share']
        lines.append(f"| {rank} | {iso3} | {r['resid_sd_pct']:.2f} | "
                     f"{r['mean_gdp']:.0f} | {r['mean_ca_usd']:+.0f} | "
                     f"{r['var_share']:.1f} | {cum:.1f} |")
    lines.append(f"\n| | **Effective N** | | | | **{eff_n:.1f}** | HHI={hhi*10000:.0f} |")

    # Panel B: Flow concentration
    lines.append(f"\n## Panel B: Flow Concentration ({latest_year})\n")
    lines.append("| Metric | Value |")
    lines.append("|:---|---:|")
    lines.append(f"| Total absolute CA flows | ${total_abs_ca:.0f}bn |")
    lines.append(f"| Total surpluses | ${total_surplus:.0f}bn |")
    lines.append(f"| Total deficits | ${total_deficit:.0f}bn |")
    lines.append(f"| US share of global deficits | {us_deficit_share:.1f}% |")
    lines.append(f"| US share of absolute CA flows | {abs(us_deficit)/total_abs_ca*100:.1f}% |")
    lines.append(f"| CHN+DEU+JPN share of global surpluses | {big3_share:.1f}% |")
    lines.append(f"| Top 5 surplus share | {top5_surplus_share:.1f}% |")
    lines.append(f"| Top 5 deficit share | {top5_deficit_share:.1f}% |")

    # Panel C: Counterfactual scenarios
    lines.append("\n## Panel C: Closure Counterfactuals\n")
    for sc in scenarios:
        lines.append(f"### {sc['name'].title()}\n")
        lines.append(f"*Flow removed: ${sc['flow_removed']:.0f}bn "
                     f"({sc['share_of_system']:.1f}% of absolute CA flows)*\n")
        sdf = sc['stranded']
        # Surplus-side scenarios carry 'stranded'; deficit-side carry 'lost_financing'.
        if 'stranded' in sdf.columns:
            lines.append("| Surplus Country | Surplus ($bn) | Stranded ($bn) |")
            lines.append("|:---|---:|---:|")
            for _, r in sdf.iterrows():
                lines.append(f"| {r['iso3']} | {r['surplus']:.0f} | {r['stranded']:.0f} |")
        elif 'lost_financing' in sdf.columns:
            lines.append("| Deficit Country | Deficit ($bn) | Lost Financing ($bn) |")
            lines.append("|:---|---:|---:|")
            for _, r in sdf.iterrows():
                lines.append(f"| {r['iso3']} | {r['deficit']:.0f} | {r['lost_financing']:.0f} |")
        lines.append("")

    # Panel D: System aggregate
    lines.append("## Panel D: System-Level Aggregate\n")
    lines.append("| Year | Aggregate Residual ($bn) | N Countries |")
    lines.append("|---:|---:|---:|")
    for _, r in yearly_agg.iterrows():
        lines.append(f"| {r['year']:.0f} | {r['total_resid_usd']:.0f} | {r['n_countries']:.0f} |")
    lines.append(f"\n| **System SD** | **${sys_sd:.0f}bn** | |")

    lines.append("\n*Panel A: Residual variance from PanelGLS (ca_gdp ~ Z + controls, excl KAOPEN), "
                 "weighted by GDP² to convert to USD. 2015–present. "
                 "Effective N = 1/HHI of variance shares.*")
    lines.append("*Panel B: Cross-section flow concentration. "
                 "Panel C: Proportional reallocation assumption — if node closes, "
                 "its counterparties lose financing pro rata. "
                 "Panel D: Sum of USD-weighted residuals across all countries per year.*")

    path = OUT_TABLES / "systemic_risk.md"
    path.write_text('\n'.join(lines))
    print(f"\n  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════

def main():
    """Run all Phase 9 steps in order on data/processed/crises_panel.csv."""
    rule = "=" * 70
    print(rule)
    print("PHASE 9: CAPITAL FLOW RISK — THE RISK MANAGER'S PERSPECTIVE")
    print(rule)

    panel = pd.read_csv(DATA / "crises_panel.csv")
    print(f"Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries")

    # Step 1 builds the residual sample every later step works on.
    resid_df, _gls, _x_vars = extract_residuals(panel)
    # Step 2 adds kaopen_tercile (used by Step 5) and the moments/tests Step 4 reports.
    resid_df, moments, variance_tests = compute_moments(resid_df)
    # Steps 3–7 write their own tables; their return values are not needed here.
    quantile_regression(resid_df)
    risk_premium(resid_df, moments, variance_tests)
    risk_matrix(resid_df)
    vulnerability_rankings(resid_df)
    systemic_risk(resid_df)

    print("\n" + rule)
    print("PHASE 9 COMPLETE — 6 output tables written")
    print(rule)


if __name__ == '__main__':
    main()
